{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## loading Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "path2content= \"./data/content.jpg\"\n",
    "path2style= \"./data/style.jpg\"\n",
    "content_img = Image.open(path2content)\n",
    "style_img = Image.open(path2style)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "content_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "style_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "\n",
    "h, w = 256, 384 \n",
    "mean_rgb = (0.485, 0.456, 0.406)\n",
    "std_rgb = (0.229, 0.224, 0.225)\n",
    "\n",
    "transformer = transforms.Compose([\n",
    "                    transforms.Resize((h,w)),  \n",
    "                    transforms.ToTensor(),\n",
    "                    transforms.Normalize(mean_rgb, std_rgb)])  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "content_tensor = transformer(content_img)\n",
    "print(content_tensor.shape, content_tensor.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "style_tensor = transformer(style_img)\n",
    "print(style_tensor.shape, style_tensor.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_tensor = content_tensor.clone().requires_grad_(True)\n",
    "print(input_tensor.shape, input_tensor.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def imgtensor2pil(img_tensor):\n",
    "    img_tensor_c = img_tensor.clone().detach()\n",
    "    img_tensor_c*=torch.tensor(std_rgb).view(3, 1,1)\n",
    "    img_tensor_c+=torch.tensor(mean_rgb).view(3, 1,1)\n",
    "    img_tensor_c = img_tensor_c.clamp(0,1)\n",
    "    img_pil=to_pil_image(img_tensor_c)\n",
    "    return img_pil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pylab as plt\n",
    "%matplotlib inline\n",
    "from torchvision.transforms.functional import to_pil_image\n",
    "\n",
    "plt.imshow(imgtensor2pil(content_tensor))\n",
    "plt.title(\"content image\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(imgtensor2pil(style_tensor))\n",
    "plt.title(\"style image\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.models as models\n",
    "\n",
    "device = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model_vgg = models.vgg19(pretrained=True).features.to(device).eval()\n",
    "for param in model_vgg.parameters():\n",
    "    param.requires_grad_(False)   \n",
    "print(model_vgg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_features(x, model, layers):\n",
    "    features = {}\n",
    "    for name, layer in enumerate(model.children()):\n",
    "        x = layer(x)\n",
    "        if str(name) in layers:\n",
    "            features[layers[str(name)]] = x\n",
    "    return features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gram_matrix(x):\n",
    "    n, c, h, w = x.size()\n",
    "    x = x.view(n*c, h * w)\n",
    "    gram = torch.mm(x, x.t())\n",
    "    return gram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "def get_content_loss(pred_features, target_features, layer):\n",
    "    target= target_features[layer]\n",
    "    pred = pred_features [layer]\n",
    "    loss = F.mse_loss(pred, target)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_style_loss(pred_features, target_features, style_layers_dict):  \n",
    "    loss = 0\n",
    "    for layer in style_layers_dict:\n",
    "        pred_fea = pred_features[layer]\n",
    "        pred_gram = gram_matrix(pred_fea)\n",
    "        n, c, h, w = pred_fea.shape\n",
    "        target_gram = gram_matrix (target_features[layer])\n",
    "        layer_loss = style_layers_dict[layer] *  F.mse_loss(pred_gram, target_gram)\n",
    "        loss += layer_loss/ (n* c * h * w)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_layers = {'0': 'conv1_1',\n",
    "                  '5': 'conv2_1',\n",
    "                  '10': 'conv3_1',\n",
    "                  '19': 'conv4_1',\n",
    "                  '21': 'conv4_2',  \n",
    "                  '28': 'conv5_1'}\n",
    "\n",
    "con_tensor = content_tensor.unsqueeze(0).to(device)\n",
    "sty_tensor = style_tensor.unsqueeze(0).to(device)\n",
    "\n",
    "content_features = get_features(con_tensor, model_vgg, feature_layers)\n",
    "style_features = get_features(sty_tensor, model_vgg, feature_layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in content_features.keys():\n",
    "    print(content_features[key].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "\n",
    "input_tensor = con_tensor.clone().requires_grad_(True)\n",
    "optimizer = optim.Adam([input_tensor], lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 300\n",
    "content_weight = 1e1\n",
    "style_weight = 1e4\n",
    "content_layer = \"conv5_1\"\n",
    "style_layers_dict = { 'conv1_1': 0.75,\n",
    "                      'conv2_1': 0.5,\n",
    "                      'conv3_1': 0.25,\n",
    "                      'conv4_1': 0.25,\n",
    "                      'conv5_1': 0.25}\n",
    "\n",
    "for epoch in range(num_epochs+1):\n",
    "    optimizer.zero_grad()\n",
    "    input_features = get_features(input_tensor, model_vgg, feature_layers)\n",
    "    content_loss = get_content_loss (input_features, content_features, content_layer)\n",
    "    style_loss = get_style_loss(input_features, style_features, style_layers_dict)\n",
    "    neural_loss = content_weight * content_loss + style_weight * style_loss\n",
    "    neural_loss.backward(retain_graph=True)\n",
    "    optimizer.step()\n",
    "    \n",
    "    if epoch % 100 == 0:\n",
    "        print('epoch {}, content loss: {:.2}, style loss {:.2}'.format(\n",
    "          epoch,content_loss, style_loss))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(imgtensor2pil(input_tensor[0].cpu()));"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
